# 模型微调

from keras.applications import VGG16
from keras import models, optimizers
from keras import layers
import matplotlib.pyplot as plt
import os
import itertools
import cv2

import numpy as np
from sklearn.metrics import confusion_matrix

from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model


# VGG16 convolutional base used as the feature extractor.
#   weights='imagenet'         -- checkpoint used to initialise the weights.
#   include_top=False          -- omit VGG16's own densely connected classifier
#                                 (which targets the 1000 ImageNet classes); a
#                                 two-class (cats vs. dogs) head is added below.
#   input_shape=(150, 150, 3)  -- shape of the image tensors fed to the network.
# Alternative: a local copy of the no-top ImageNet weights. Keras' VGG16 also
# accepts a weights-file path via `weights=` if you want to avoid downloading.
# path = "../mydatas/vgg16_weights_tf_dim_ordering_tf_kernels_notop (1).h5"

conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(150, 150, 3))

# Binary classifier: convolutional base -> flatten -> 256-unit ReLU -> sigmoid.
# NOTE: `model` is later replaced by the checkpoint loaded with load_model(),
# so this architecture only matters if training is re-enabled in this script.
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
conv_base.summary()


# Fine-tune only the last three convolutional layers: every layer up to and
# including block4_pool stays frozen, while block5_conv1, block5_conv2 and
# block5_conv3 become trainable.
conv_base.trainable = True

# Walk the layers in order. Once block5_conv1 is reached, it and every layer
# after it are trainable; every earlier layer is frozen. (Any pooling layer
# after block5_conv3 is flagged too, but pooling layers hold no weights.)
set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'block5_conv1':
        set_trainable = True
    layer.trainable = set_trainable


# Data pipeline for fine-tuning.
# Dataset root containing train/, validation/ and test/ sub-directories, each
# with one sub-folder per class (the layout flow_from_directory reads).
# The default is a machine-specific path; set the CATS_DOGS_DIR environment
# variable to point the script at another location.
base_dir = os.environ.get('CATS_DOGS_DIR', "G:\\python\\keshe\\train1")
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
test_dir = os.path.join(base_dir, 'test')

# Training images are randomly augmented on the fly. Every generator rescales
# pixel values from [0, 255] to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation/test images are only rescaled, never augmented.
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)

# Evaluate the previously fine-tuned model saved to disk.
# (load_model is already imported at the top of the file.)
model = load_model('../data/cats_and_dogs_wei_tiao.h5')

# Evaluate the saved model directly on the test split; no retraining needed.
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary'
)
# One step per batch, so every test image is evaluated exactly once whatever
# the size of the test set (a fixed step count would repeat or skip images).
# NOTE(review): evaluate_generator is deprecated in newer Keras releases in
# favour of model.evaluate(generator, ...); kept for compatibility here.
test_loss, test_acc = model.evaluate_generator(test_generator, steps=len(test_generator))
print('test_acc:' + str(test_acc))
print('test_loss:' + str(test_loss))

# NOTE(review): this generator (default class_mode/batch_size) is not used
# anywhere in this script.
test_generator2 = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
)

def get_input_xy(src=None, target_size=(150, 150), class_indices=None):
    """Load images from disk and derive their labels from the file names.

    Each image is read with OpenCV, resized to ``target_size``, converted from
    OpenCV's BGR channel order to RGB, and finally divided by 255. The label
    is looked up from the first three characters of the file name
    (e.g. ``cat.12.jpg`` -> ``'cat'``).

    Args:
        src: iterable of image file paths. ``None`` (the default) means no
            images.
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        class_indices: mapping from the 3-character file-name prefix to an
            integer label. Defaults to ``{'cat': 0, 'dog': 1}``.

    Returns:
        ``(pre_x, true_y)``: ``pre_x`` is the stacked images as a float array
        divided by 255; ``true_y`` is a list with one integer label per image.

    Raises:
        ValueError: if an image cannot be read, or its file name does not
            start with a prefix present in ``class_indices``.
    """
    # Avoid a shared mutable default; None stands for "no images".
    if src is None:
        src = []
    if class_indices is None:
        class_indices = {'cat': 0, 'dog': 1}

    pre_x = []
    true_y = []

    for path in src:
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError('could not read image: {!r}'.format(path))
        img = cv2.resize(img, target_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pre_x.append(img)

        fn = os.path.basename(path)
        label = class_indices.get(fn[:3])
        if label is None:
            # A None label would only fail later, inside confusion_matrix.
            raise ValueError('cannot infer class label from file name: {!r}'.format(fn))
        true_y.append(label)

    pre_x = np.array(pre_x) / 255.0

    return pre_x, true_y


def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map and show it.

    The function name keeps its original spelling so existing callers work.

    Args:
        cm: 2-D confusion matrix as returned by
            ``sklearn.metrics.confusion_matrix`` (rows = true labels,
            columns = predicted labels).
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum so every cell shows
            the fraction of that true class; all-zero rows stay at 0.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalise BEFORE drawing so the colours, the text-colour threshold
        # and the cell annotations all describe the same numbers.
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # White text on dark cells, black text on light ones.
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell,
                 horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predict label')
    plt.show()


# Class sub-folders of the test split (test_dir/<class>/<image>), sorted so
# the image order is deterministic.
test = sorted(os.listdir(test_dir))

images = []

# Collect the path of every JPEG test image into `images`.
for testpath in test:
    class_dir = os.path.join(test_dir, testpath)
    if not os.path.isdir(class_dir):
        continue  # skip stray files sitting next to the class folders
    for fn in sorted(os.listdir(class_dir)):
        if fn.lower().endswith(('.jpg', '.jpeg')):
            fd = os.path.join(class_dir, fn)
            images.append(fd)

if not images:
    raise RuntimeError('no .jpg test images found under ' + test_dir)

# Normalised image array and the true labels (from the file-name prefixes).
pre_x, true_y = get_input_xy(images)

# Concatenating a str with an ndarray or list raises, so convert explicitly;
# print the array's shape rather than dumping every pixel.
print('pre_x shape: ' + str(pre_x.shape))
print('true_y: ' + str(true_y))

# Predict: sigmoid output > 0.5 -> class 1. Flatten the (N, 1) output to (N,).
pred_y = model.predict(pre_x)
pred_y = np.int64(pred_y > 0.5).ravel()

# Plot the confusion matrix. labels=[0, 1] keeps it 2x2 even if one class is
# absent; tick labels follow the cat=0 / dog=1 mapping used by get_input_xy.
confusion_mat = confusion_matrix(true_y, pred_y, labels=[0, 1])
plot_sonfusion_matrix(confusion_mat, classes=['cat', 'dog'])

